SA-NET

SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS

The goal of the paper is to provide attention while keeping the network's computational overhead low.

import torch
import torch.nn as nn
from torch.nn import Parameter


class sa_layer(nn.Module):
    """Constructs a Shuffle Attention (channel + spatial group) module.

    Args:
        channel: number of input channels
        groups: number of groups the channels are split into
    """

    def __init__(self, channel, groups=64):
        super(sa_layer, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Trainable scale/shift for the channel-attention branch,
        # initialized to 0 and 1 respectively.
        self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        # Trainable scale/shift for the spatial-attention branch.
        self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))

        self.sigmoid = nn.Sigmoid()
        # num_groups == num_channels, so this GroupNorm behaves like InstanceNorm.
        self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape

        # (b, c, h, w) -> (b, groups, c // groups, h, w), then swap the group
        # and channel axes so channels from different groups interleave
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten back to (b, c, h, w)
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.shape

        # split the channels into `groups` sub-features, processed independently
        x = x.reshape(b * self.groups, -1, h, w)
        # each group is further split in half: one branch for channel attention,
        # the other for spatial attention
        x_0, x_1 = x.chunk(2, dim=1)

        # channel attention: global average pooling + scale/shift + sigmoid gate
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)

        # spatial attention: GroupNorm + scale/shift + sigmoid gate
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)

        # concatenate the two branches along the channel axis
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)

        # channel shuffle to let information flow across groups
        out = self.channel_shuffle(out, 2)
        return out
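
A minimal usage sketch (the batch size, channel count, and group count below are illustrative values I chose, not from the paper); it assumes the imports and the sa_layer definition above:

# Example usage (illustrative values):
x = torch.randn(2, 256, 32, 32)          # a batch of 256-channel feature maps
sa = sa_layer(channel=256, groups=64)    # 256 channels split into 64 groups of 4
y = sa(x)
print(y.shape)                           # torch.Size([2, 256, 32, 32]), same shape as the input

Because the output shape matches the input, the module can be dropped in after any convolutional block without changing the surrounding architecture.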

Compared with previous attention modules, the differences are:

  1. The spatial-attention branch uses GroupNorm, which in this code behaves exactly like InstanceNorm, since the number of groups equals the number of channels (see the small check after this list).
  2. The feature map is first split along the channel dimension into G groups, and attention is computed independently within each group.
  3. After the channel- and spatial-attention statistics are computed, $W_i \in \mathbb{R}^{C/2G \times 1 \times 1}$ and $b_i \in \mathbb{R}^{C/2G \times 1 \times 1}$ are applied as weight and bias before the sigmoid. Both are trainable parameters, initialized to 0 and 1 respectively.
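
As a quick sanity check of point 1, here is a small sketch (my own verification code, not from the paper or the official repo): a GroupNorm whose number of groups equals its number of channels normalizes each channel of each sample over H and W, which is exactly what InstanceNorm2d computes.

import torch
import torch.nn as nn

c = 4
x = torch.randn(2, c, 8, 8)
gn = nn.GroupNorm(c, c)        # one channel per group; affine weight=1, bias=0 at init
inorm = nn.InstanceNorm2d(c)   # per-channel, per-sample normalization, no affine by default
print(torch.allclose(gn(x), inorm(x), atol=1e-5))  # expected: True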

Overall, the paper does not feel particularly novel.